from OGM_dataset import OGMMotionCompensatedDataset, OGMPlainDataset
from torch.utils.data import DataLoader



if __name__ == "__main__":
    # Dataset location; an empty placeholder here. Presumably it must point at
    # the real dataset before this script is useful. TODO confirm.
    dataset_root = ""
    # Split to load: "train", "val" or "test".
    split = "train"

    # NOTE(review): "test_datset" may be a typo for "test_dataset". It is left
    # unchanged because the constructor's use of this argument is not visible here.
    dataset = OGMMotionCompensatedDataset("test_datset", dataset_root, split)

    # Shuffled mini-batches of 32; drop_last=True discards a final incomplete batch.
    loader = DataLoader(
        dataset,
        batch_size=32,
        shuffle=True,
        drop_last=True,
    )

    for batch in loader:
        # Each batch is accessed by key: the history and future sequences.
        hist_seq = batch['hist_seq']
        future_seq = batch['future_seq']
        # Training step goes here.
        # TODO: typically only the first frame of future_seq is wanted as the target.